{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[View the runnable example on GitHub](https://github.com/intel-analytics/BigDL/tree/main/python/nano/tutorial/notebook/inference/pytorch/quantize_pytorch_inference_inc.ipynb)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Quantize PyTorch Model in INT8 for Inference using Intel Neural Compressor"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With Intel Neural Compressor (INC) as the quantization engine, you can use the `InferenceOptimizer.quantize` API to apply INT8 post-training quantization to your PyTorch `nn.Module`. `InferenceOptimizer.quantize` also supports ONNXRuntime acceleration at the same time by specifying `accelerator='onnxruntime'`. Either option takes only a few lines of code."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "To quantize your model with INC, you need to install BigDL-Nano for PyTorch inference first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "!pip install --pre --upgrade bigdl-nano[pytorch,inference] # install the nightly-built version\n",
    "!source bigdl-nano-init"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "source": [
    "> 📝 **Note**\n",
    "> \n",
    "> We recommend running the commands above, especially `source bigdl-nano-init`, before the Jupyter kernel is started; otherwise some of the optimizations may not take effect."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a [ResNet-18 model](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html) pretrained on the ImageNet dataset and finetuned on the [OxfordIIITPet dataset](https://pytorch.org/vision/main/generated/torchvision.datasets.OxfordIIITPet.html) as an example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# Define the finetune function\n",
    "import torch\n",
    "from torchvision import transforms\n",
    "from torchvision.datasets import OxfordIIITPet\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from torchvision.models import resnet18\n",
    "from bigdl.nano.pytorch import Trainer\n",
    "from torchmetrics.classification import MulticlassAccuracy\n",
    "\n",
    "\n",
    "def finetune_pet_dataset(model_ft):\n",
    "\n",
    "    train_transform = transforms.Compose([transforms.Resize(256),\n",
    "                                          transforms.RandomCrop(224),\n",
    "                                          transforms.RandomHorizontalFlip(),\n",
    "                                          transforms.ColorJitter(brightness=.5, hue=.3),\n",
    "                                          transforms.ToTensor(),\n",
    "                                          transforms.Normalize([0.485, 0.456, 0.406],\n",
    "                                                               [0.229, 0.224, 0.225])])\n",
    "    val_transform = transforms.Compose([transforms.Resize(256),\n",
    "                                        transforms.CenterCrop(224),\n",
    "                                        transforms.ToTensor(),\n",
    "                                        transforms.Normalize([0.485, 0.456, 0.406],\n",
    "                                                             [0.229, 0.224, 0.225])])\n",
    "\n",
    "    # apply data augmentation to the train_dataset\n",
    "    train_dataset = OxfordIIITPet(root=\"/tmp/data\",\n",
    "                                  transform=train_transform,\n",
    "                                  download=True)\n",
    "    val_dataset = OxfordIIITPet(root=\"/tmp/data\",\n",
    "                                transform=val_transform)\n",
    "\n",
    "    # obtain training indices that will be used for validation\n",
    "    indices = torch.randperm(len(train_dataset))\n",
    "    val_size = len(train_dataset) // 4\n",
    "    train_dataset = torch.utils.data.Subset(train_dataset, indices[:-val_size])\n",
    "    val_dataset = torch.utils.data.Subset(val_dataset, indices[-val_size:])\n",
    "\n",
    "    # prepare data loaders\n",
    "    train_dataloader = DataLoader(train_dataset, batch_size=32)\n",
    "    val_dataloader = DataLoader(val_dataset, batch_size=32)\n",
    "\n",
    "    num_ftrs = model_ft.fc.in_features\n",
    "\n",
    "    # here the size of each output sample is set to 37.\n",
    "    model_ft.fc = torch.nn.Linear(num_ftrs, 37)\n",
    "    loss_ft = torch.nn.CrossEntropyLoss()\n",
    "    optimizer_ft = torch.optim.SGD(model_ft.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
    "\n",
    "    # compile our model with loss function, optimizer.\n",
    "    model = Trainer.compile(model_ft, loss_ft, optimizer_ft, metrics=[MulticlassAccuracy(num_classes=37)])\n",
    "    trainer = Trainer(max_epochs=1)\n",
    "    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)\n",
    "\n",
    "    return model, train_dataset, val_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision.models import resnet18\n",
    "\n",
    "model = resnet18(pretrained=True)\n",
    "_, train_dataset, val_dataset = finetune_pet_dataset(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;_The full definition of function_ `finetune_pet_dataset` _could be found in the_ [runnable example](https://github.com/intel-analytics/BigDL/tree/main/python/nano/tutorial/notebook/inference/pytorch/quantize_pytorch_inference_inc.ipynb)."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To enable INT8 quantization using INC for inference, you can simply **import BigDL-Nano** `InferenceOptimizer`**, and use** `InferenceOptimizer` **to quantize your PyTorch model**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.pytorch import InferenceOptimizer\n",
    "\n",
    "q_model = InferenceOptimizer.quantize(model, \n",
    "                                      calib_data=DataLoader(train_dataset, batch_size=32))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you also want to enable ONNXRuntime acceleration at the same time, just specify the `accelerator` parameter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigdl.nano.pytorch import InferenceOptimizer\n",
    "\n",
    "q_model = InferenceOptimizer.quantize(model,\n",
    "                                      accelerator='onnxruntime',\n",
    "                                      calib_data=DataLoader(train_dataset, batch_size=32))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    "> \n",
    "> The `InferenceOptimizer.quantize` function has a `precision` parameter that specifies the precision for quantization. It defaults to `'int8'`, so we omit the `precision` parameter here for INT8 quantization.\n",
    ">\n",
    "> During INT8 quantization using INC, `InferenceOptimizer` will by default quantize your PyTorch `nn.Module` through **static** post-training quantization. In this case, `calib_data` (the calibration data) is required. The batch size of `calib_data` is not important, as calibration intends to read 100 samples. The calibration data does not need to contain labels.\n",
    "> \n",
    "> If you would like to use dynamic post-training quantization instead, you can set the parameter `approach='dynamic'`. In this case, `calib_data` should be `None`. Compared to dynamic quantization, static quantization can lead to faster inference, as it eliminates the data conversion costs between layers.\n",
    "> \n",
    "> Please refer to the [API documentation](https://bigdl.readthedocs.io/en/latest/doc/PythonAPI/Nano/pytorch.html#bigdl.nano.pytorch.InferenceOptimizer.quantize) for more information on `InferenceOptimizer.quantize`."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can then run the normal inference steps with the quantized model **under the context manager provided by Nano**:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with InferenceOptimizer.get_context(q_model):\n",
    "    x = torch.stack([val_dataset[0][0], val_dataset[1][0]])\n",
    "    # use the quantized model here\n",
    "    y_hat = q_model(x)\n",
    "    predictions = y_hat.argmax(dim=1)\n",
    "    print(predictions)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📝 **Note**\n",
    "> \n",
    "> For all models optimized by `InferenceOptimizer.quantize`, you need to wrap the inference steps with the automatic context manager `InferenceOptimizer.get_context(model=...)` provided by Nano. Refer to [this guide](https://bigdl.readthedocs.io/en/latest/doc/Nano/Howto/Inference/PyTorch/pytorch_context_manager.html) for more detailed usage of the context manager."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> 📚 **Related Readings**\n",
    "> \n",
    "> - [How to install BigDL-Nano](https://bigdl.readthedocs.io/en/latest/doc/Nano/Overview/install.html)\n",
    "> - [How to enable automatic context management for PyTorch inference on Nano optimized models](https://bigdl.readthedocs.io/en/latest/doc/Nano/Howto/Inference/PyTorch/pytorch_context_manager.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 ('nano-pytorch': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.7.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "09344c7f3239fd422839751f876786d6b1a624c40f19af1b43cb2737f421c2b2"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
